package ip4region

import (
	"errors"
	"io/ioutil"
	"os"

	"gitee.com/RickieL/comm"
)

// Layout of an ip4region database file. All lengths and offsets are in bytes.
const (
	// SuperBlockLength is the size of the super block at offset 0, which holds
	// the offsets of the first and last index blocks (4 bytes each).
	SuperBlockLength int64 = 8

	// TotalHeaderLength is the size of the header block that follows the super
	// block: 16 KiB, i.e. room for 2048 eight-byte index-partition entries.
	TotalHeaderLength int64 = 16 << 10

	// DataStartPos is the offset at which data begins, right after the super
	// block and the header block.
	DataStartPos int64 = SuperBlockLength + TotalHeaderLength

	// IndexBlockLength is the size of one index block: start IP, end IP and
	// data pointer, 4 bytes each.
	IndexBlockLength int64 = 12
)

// NOTE(review): err and ipInfo are package-level mutable state. New declares a
// local err, and MemorySearch, BinarySearch and BtreeSearch all declare named
// results called ipInfo and err. Those shadow these variables, so they look
// unused in this file. Confirm no other file in the package relies on them
// before removing.
var err error
var ipInfo IPInfo

// IP4Region is a handle to an ip4region database file. It offers three lookup
// modes: MemorySearch, BinarySearch and BtreeSearch. Each mode lazily loads the
// metadata it needs on first use and caches it in the fields below.
//
// NOTE(review): that lazy initialization writes these fields without any
// synchronization, so sharing one instance across goroutines looks unsafe
// without external locking. Confirm before relying on concurrent use.
type IP4Region struct {
	// open database file handle (set by New, closed by Close)
	dbFileHandler *os.File

	// header block info, loaded by BtreeSearch: the start IP and index-block
	// pointer of each header entry, plus the number of entries
	headerSip []int64
	headerPtr []int64
	headerLen int64

	// super block index info: offsets of the first and last index blocks and
	// the number of index blocks from first to last, inclusive
	firstIndexPtr int64
	lastIndexPtr  int64
	totalBlocks   int64

	// for memory mode only:
	// dbBinStr is the entire database file read into memory by MemorySearch;
	// dbFile is the database path given to New (read by MemorySearch)
	dbBinStr []byte
	dbFile   string
}

// IPInfo is the location record returned by the search methods. It is decoded
// from the raw data record by getIPInfo, which is defined outside this file
// section. The field meanings follow their names.
type IPInfo struct {
	Country  string
	Region   string
	Province string
	City     string
	ISP      string
}

// New opens the ip4region database at path (read-only) and returns a handle
// for querying it. The file stays open until Close is called. No index data is
// read here; the search methods load what they need on first use.
func New(path string) (*IP4Region, error) {
	handle, openErr := os.Open(path)
	if openErr != nil {
		return nil, openErr
	}

	region := new(IP4Region)
	region.dbFile = path
	region.dbFileHandler = handle
	return region, nil
}

// Close closes the database file handle opened by New.
func (ip4 *IP4Region) Close() {
	// The file is opened read-only, so a close failure loses no data. The error
	// is discarded on purpose because this method has no error result.
	_ = ip4.dbFileHandler.Close()
}

// MemorySearch looks ipStr up entirely in memory.
//
// On the first call it reads the whole database file (ip4.dbFile) into
// ip4.dbBinStr and caches the first/last index-block offsets from its super
// block. The file is only cached once it has been validated, so a failed load
// is retried on the next call. Later calls search the cached bytes without
// touching the file.
//
// It returns an error when the file cannot be read or is malformed, when ipStr
// cannot be parsed, or when no index block covers the address ("not found").
func (ip4 *IP4Region) MemorySearch(ipStr string) (ipInfo IPInfo, err error) {
	if len(ip4.dbBinStr) == 0 {
		data, readErr := ioutil.ReadFile(ip4.dbFile)
		if readErr != nil {
			return IPInfo{}, readErr
		}
		if int64(len(data)) < SuperBlockLength {
			return IPInfo{}, errors.New("invalid db file")
		}

		first := getLongLitte(data, 0)
		last := getLongLitte(data, 4)
		// Every index block read below must lie inside the file.
		if first < 0 || last < first || last+IndexBlockLength > int64(len(data)) {
			return IPInfo{}, errors.New("invalid db file")
		}

		ip4.dbBinStr = data
		ip4.firstIndexPtr = first
		ip4.lastIndexPtr = last
		ip4.totalBlocks = (last-first)/IndexBlockLength + 1
	}

	ip, err := comm.IP2Long(ipStr)
	if err != nil {
		return IPInfo{}, err
	}

	// Binary search over the inclusive block range [0, totalBlocks-1]. The old
	// upper bound of totalBlocks could read one block past lastIndexPtr.
	var dataPtr int64
	l, h := int64(0), ip4.totalBlocks-1
	for l <= h {
		m := (l + h) >> 1
		p := ip4.firstIndexPtr + m*IndexBlockLength
		if sip := ip4.getLongLitte(p); ip < sip {
			h = m - 1
			continue
		}
		if eip := ip4.getLongLitte(p + 4); ip > eip {
			l = m + 1
			continue
		}
		dataPtr = ip4.getLong(p + 8)
		break
	}
	if dataPtr == 0 {
		return IPInfo{}, errors.New("not found")
	}

	// The data pointer packs the record length in bits 24-31 and the record
	// offset in bits 0-23.
	dataLen := (dataPtr >> 24) & 0xFF
	dataPtr &= 0x00FFFFFF
	if dataPtr+dataLen > int64(len(ip4.dbBinStr)) {
		return IPInfo{}, errors.New("invalid db file")
	}
	return getIPInfo(ip4.dbBinStr[dataPtr : dataPtr+dataLen]), nil
}

// BinarySearch looks ipStr up by binary-searching the index blocks directly in
// the database file, reading one index block per probe.
//
// On the first call it reads the super block (first/last index-block offsets)
// and caches it on ip4. The cache is only filled after a successful, sane read.
// All reads use ReadAt, which reports any short read as an error and does not
// depend on the file's current seek offset.
//
// It returns an error when the file cannot be read or the super block is
// invalid, when ipStr cannot be parsed, or when no index block covers the
// address ("not found"). On error ipInfo is the zero value.
func (ip4 *IP4Region) BinarySearch(ipStr string) (ipInfo IPInfo, err error) {
	if ip4.totalBlocks == 0 {
		superBlock := make([]byte, SuperBlockLength)
		if _, err = ip4.dbFileHandler.ReadAt(superBlock, 0); err != nil {
			return IPInfo{}, err
		}
		first := getLongLitte(superBlock, 0)
		last := getLongLitte(superBlock, 4)
		if last < first {
			return IPInfo{}, errors.New("invalid super block")
		}
		ip4.firstIndexPtr = first
		ip4.lastIndexPtr = last
		ip4.totalBlocks = (last-first)/IndexBlockLength + 1
	}

	ip, err := comm.IP2Long(ipStr)
	if err != nil {
		return IPInfo{}, err
	}

	// Binary search over the inclusive block range [0, totalBlocks-1]. The old
	// upper bound of totalBlocks probed one block past lastIndexPtr.
	var dataPtr int64
	buffer := make([]byte, IndexBlockLength)
	l, h := int64(0), ip4.totalBlocks-1
	for l <= h {
		m := (l + h) >> 1
		if _, err = ip4.dbFileHandler.ReadAt(buffer, ip4.firstIndexPtr+m*IndexBlockLength); err != nil {
			return IPInfo{}, err
		}
		if sip := getLongLitte(buffer, 0); ip < sip {
			h = m - 1
			continue
		}
		if eip := getLongLitte(buffer, 4); ip > eip {
			l = m + 1
			continue
		}
		dataPtr = getLong(buffer, 8)
		break
	}

	if dataPtr == 0 {
		return IPInfo{}, errors.New("not found")
	}

	// The data pointer packs the record length in bits 24-31 and the record
	// offset in bits 0-23.
	dataLen := (dataPtr >> 24) & 0xFF
	dataPtr &= 0x00FFFFFF

	data := make([]byte, dataLen)
	if _, err = ip4.dbFileHandler.ReadAt(data, dataPtr); err != nil {
		return IPInfo{}, err
	}
	return getIPInfo(data), nil
}

// BtreeSearch looks ipStr up using the header index stored after the super
// block ("b-tree" mode).
//
// On the first call it loads the header block (TotalHeaderLength bytes at
// offset SuperBlockLength) into ip4.headerSip / ip4.headerPtr. Each 8-byte
// entry is a start IP followed by an index-block pointer, and a zero pointer
// ends the list. The cache is only filled after a successful read that yields
// at least two entries.
//
// A lookup then works in three steps:
//  1. Binary-search the in-memory header for the pair of adjacent pointers
//     that bracket ipStr.
//  2. Read the index blocks from the start pointer up to and including the
//     block at the end pointer, in one ReadAt.
//  3. Binary-search those blocks for the data record.
//
// It returns an error when ipStr cannot be parsed, the file cannot be read,
// the header is unusable, or no index block covers the address ("not found").
// On error ipInfo is the zero value.
func (ip4 *IP4Region) BtreeSearch(ipStr string) (ipInfo IPInfo, err error) {
	ip, err := comm.IP2Long(ipStr)
	if err != nil {
		return IPInfo{}, err
	}

	if ip4.headerLen == 0 {
		buffer := make([]byte, TotalHeaderLength)
		// ReadAt reports any short read as an error, unlike Read.
		if _, err = ip4.dbFileHandler.ReadAt(buffer, SuperBlockLength); err != nil {
			return IPInfo{}, err
		}

		var sips, ptrs []int64
		for i := int64(0); i < TotalHeaderLength; i += 8 {
			dataPar := getLongLitte(buffer, i+4)
			if dataPar == 0 {
				break // a zero pointer terminates the header list
			}
			sips = append(sips, getLongLitte(buffer, i))
			ptrs = append(ptrs, dataPar)
		}
		// The partition search below always pairs adjacent entries
		// (m-1/m or m/m+1). With fewer than two entries the old code looped
		// forever (zero entries) or indexed out of range (one entry).
		if len(sips) < 2 {
			return IPInfo{}, errors.New("invalid header block")
		}

		ip4.headerSip, ip4.headerPtr = sips, ptrs
		ip4.headerLen = int64(len(sips))
	}

	// Pick the header partition [sptr, eptr] that contains ip. headerLen >= 2
	// guarantees m < headerLen on every probe, because l only ever advances to
	// m+1 for m <= headerLen-2.
	var sptr, eptr int64
	l, h := int64(0), ip4.headerLen
	for l <= h {
		m := (l + h) >> 1
		sip := ip4.headerSip[m]

		if ip == sip {
			if m > 0 {
				sptr, eptr = ip4.headerPtr[m-1], ip4.headerPtr[m]
			} else {
				sptr, eptr = ip4.headerPtr[m], ip4.headerPtr[m+1]
			}
			break
		}

		if ip < sip {
			if m == 0 {
				sptr, eptr = ip4.headerPtr[m], ip4.headerPtr[m+1]
				break
			}
			if ip > ip4.headerSip[m-1] {
				sptr, eptr = ip4.headerPtr[m-1], ip4.headerPtr[m]
				break
			}
			h = m - 1
			continue
		}

		// ip > sip
		if m == ip4.headerLen-1 {
			sptr, eptr = ip4.headerPtr[m-1], ip4.headerPtr[m]
			break
		}
		if ip <= ip4.headerSip[m+1] {
			sptr, eptr = ip4.headerPtr[m], ip4.headerPtr[m+1]
			break
		}
		l = m + 1
	}

	if sptr == 0 {
		return IPInfo{}, errors.New("not found")
	}

	blockLen := eptr - sptr
	if blockLen < 0 {
		return IPInfo{}, errors.New("invalid header block")
	}

	// Read every index block from sptr through the block that starts at eptr.
	index := make([]byte, blockLen+IndexBlockLength)
	if _, err = ip4.dbFileHandler.ReadAt(index, sptr); err != nil {
		return IPInfo{}, err
	}

	var dataPtr int64
	l, h = 0, blockLen/IndexBlockLength
	for l <= h {
		m := (l + h) >> 1
		p := m * IndexBlockLength
		if sip := getLongLitte(index, p); ip < sip {
			h = m - 1
			continue
		}
		if eip := getLongLitte(index, p+4); ip > eip {
			l = m + 1
			continue
		}
		dataPtr = getLong(index, p+8)
		break
	}

	if dataPtr == 0 {
		return IPInfo{}, errors.New("not found")
	}

	// The data pointer packs the record length in bits 24-31 and the record
	// offset in bits 0-23.
	dataLen := (dataPtr >> 24) & 0xFF
	dataPtr &= 0x00FFFFFF

	data := make([]byte, dataLen)
	if _, err = ip4.dbFileHandler.ReadAt(data, dataPtr); err != nil {
		return IPInfo{}, err
	}
	return getIPInfo(data), nil
}
